python, decision treeのtree graph plot¶
動機¶
- decision treeのtree plotのためにいろいろ調べた
- kerasで使うので久しぶりに調べた(2017-02-16)
- 平面で境界線のplotはmatplotlibで書くので、この記事を閉じてよい
scikit-learnでdecision tree¶
- http://scikit-learn.org/stable/modules/tree.html
- http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html
plot用APIのWIP¶
https://github.com/scikit-learn/scikit-learn/pull/6380
library¶
pydot¶
- https://pypi.python.org/pypi/pydot
- https://pypi.python.org/pypi/pydot2/1.0.32
- https://pypi.python.org/pypi/pydot3
- https://pypi.python.org/pypi/pydot-ng
pydot2¶
- 上記でいろいろあげたけど、pydotを使うならこれでよさそう
- pydotと同じ作者
- source codeは公開されているけどrepositoryがない?
- sklearnのexampleにも使われているが、開発継続されていない?
- python3でやってみたけど、エラー
from sklearn.externals.six import StringIO as SkStringIO
from IPython.display import Image
import pydot
dot_data = SkStringIO()
export_graphviz(
    tree_clf, out_file=dot_data,
    feature_names=X.columns,
    class_names=["0", "1"],
    filled=True, rounded=True,
    special_characters=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())
pydot3¶
- kerasでvisualizationを試したときに使った
- python3で動いた
- forkではなくportしてcompatibleにしたcommitをしている
- いろいろあるrepositoryの中で一番直近に更新されている
pydot-ng¶
- pydotのrepositoryのhistoryを引き継いでいる個人repositoryをforkしたもの
- organizationになっている
- import pydot_ng as pydotとしてmodule名の代用をしないといけないので既存moduleではだめっぽそう- https://github.com/fchollet/keras/commit/e5d3abdf09d8c281ca8817b6292a044673ba3007
- kerasではこのmoduleを優先して(pydotとして)importしているのでこれでよい
- なければpydotがimportされるのでpydot3でも大丈夫だった
 
networkx¶
- http://pypi.python.org/pypi/networkx/
- GraphVizは必須ではなくOption
PyGraphviz¶
- http://pygraphviz.github.io
- networkxと連携できる
graphviz¶
- https://pypi.python.org/pypi/graphviz
- plot用APIのWIPではこれが使われていた
- こちらは問題なく動いた
pip install graphviz
graphviz_tree_plot.py¶
from sklearn.externals.six import StringIO as SkStringIO
from IPython.display import Image
import graphviz
def tree_plot(decision_tree, width=500, height=500, max_depth=None,
         feature_names=None, class_names=None, label='all',
         filled=False, leaves_parallel=False, impurity=True,
         node_ids=False, proportion=False, rotate=False,
         rounded=False, special_characters=False):
    in_memory_dot_file = SkStringIO()
    export_graphviz(
            decision_tree, out_file=in_memory_dot_file, max_depth=max_depth,
            feature_names=feature_names, class_names=class_names, label=label,
            filled=filled, leaves_parallel=leaves_parallel, impurity=impurity,
            node_ids=node_ids, proportion=proportion, rotate=rotate,
            rounded=rounded, special_characters=special_characters)
    src = graphviz.Source(in_memory_dot_file.getvalue())
    return Image(src.pipe(format='png'), height=height, width=width)
tree_plot(tree_clf,
    feature_names=X.columns,
    class_names=["d", "s"],
#    class_names={1: "s", 0: "d"},
    filled=True,
    rounded=True,
    special_characters=True
)
比較¶
http://plaza.rakuten.co.jp/kugutsushi/diary/200711080000/